--- external/stylegan/training/networks_stylegan2_terrain.py	2023-03-09 18:11:00.963870634 +0000
+++ external_reference/stylegan/training/networks_stylegan2_terrain.py	2023-03-09 18:07:58.572064994 +0000
@@ -20,6 +20,9 @@
 from torch_utils.ops import bias_act
 from torch_utils.ops import fma
 
+from external.gsn.models.discriminator import ConvDecoder
+from utils.utils import interpolate
+
 #----------------------------------------------------------------------------
 
 @misc.profiled_function
@@ -306,17 +309,23 @@
             self.noise_strength = torch.nn.Parameter(torch.zeros([]))
         self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
 
-    def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
-        assert noise_mode in ['random', 'const', 'none']
+    def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1, noise_input=None):
+        assert noise_mode in ['random', 'const', 'none', '3dnoise']
         in_resolution = self.resolution // self.up
-        misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution])
+        # CHANGED: shape assertion disabled -- for layout SOAT, x (and the noise sized from it) may not match in_resolution
+        # misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution])
         styles = self.affine(w)
 
         noise = None
         if self.use_noise and noise_mode == 'random':
-            noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
+            # CHANGED: layout SOAT noise may be spatially larger than self.resolution
+            noise = torch.randn([x.shape[0], 1, x.shape[2]*self.up, x.shape[3]*self.up], device=x.device) * self.noise_strength
+            # noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
         if self.use_noise and noise_mode == 'const':
             noise = self.noise_const * self.noise_strength
+        if self.use_noise and noise_mode == '3dnoise':
+            # CHANGED: support 3d projected noise input in upsampler
+            noise = interpolate(noise_input, x.shape[2]*self.up)
 
         flip_weight = (self.up == 1) # slightly faster
         x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
@@ -463,31 +472,47 @@
 
 #----------------------------------------------------------------------------
 
+# CHANGED: add support for truncated generator (upsampler)
 @persistence.persistent_class
 class SynthesisNetwork(torch.nn.Module):
     def __init__(self,
         w_dim,                      # Intermediate latent (W) dimensionality.
         img_resolution,             # Output image resolution.
         img_channels,               # Number of color channels.
+        input_resolution = 4,        # Input resolution for truncated generator (upsampler); the default 4 skips no block.
         channel_base    = 32768,    # Overall multiplier for the number of channels.
         channel_max     = 512,      # Maximum number of channels in any layer.
         num_fp16_res    = 4,        # Use FP16 for the N highest resolutions.
+        num_additional_feature_channels = 0, # Additional feature channels for input layer
+        default_noise_mode = 'random', # Noise mode used by forward() when the caller does not pass noise_mode.
         **block_kwargs,             # Arguments for SynthesisBlock.
     ):
         assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
         super().__init__()
         self.w_dim = w_dim
+        self.input_resolution = input_resolution
+        self.input_resolution_log2 = int(np.log2(input_resolution))
         self.img_resolution = img_resolution
         self.img_resolution_log2 = int(np.log2(img_resolution))
         self.img_channels = img_channels
         self.num_fp16_res = num_fp16_res
-        self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
+        self.block_resolutions = [2 ** i for i in range(self.input_resolution_log2, self.img_resolution_log2 + 1)]
         channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
         fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
 
+        self.input_channels = channels_dict[self.block_resolutions[0]]
+        self.num_additional_feature_channels = num_additional_feature_channels
+        self.default_noise_mode = default_noise_mode
+
         self.num_ws = 0
-        for res in self.block_resolutions:
+        for num_res, res in enumerate(self.block_resolutions):
+            if num_res == 0 and res > 4:
+                # for upsampler network, skip the first entry (used for input layer)
+                continue
             in_channels = channels_dict[res // 2] if res > 4 else 0
+            # CHANGED: concatenate additional feature channels as input
+            if num_res == 1:
+                in_channels += num_additional_feature_channels
             out_channels = channels_dict[res]
             use_fp16 = (res >= fp16_resolution)
             is_last = (res == self.img_resolution)
@@ -498,19 +523,27 @@
                 self.num_ws += block.num_torgb
             setattr(self, f'b{res}', block)
 
-    def forward(self, ws, **block_kwargs):
+    def forward(self, ws, x=None, img=None, **block_kwargs):
         block_ws = []
+        block_res = self.block_resolutions
+        if self.input_resolution > 4:
+            # skip input block for upsampler network
+            block_res = self.block_resolutions[1:]
+
         with torch.autograd.profiler.record_function('split_ws'):
             misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
             ws = ws.to(torch.float32)
             w_idx = 0
-            for res in self.block_resolutions:
+            for res in block_res:
                 block = getattr(self, f'b{res}')
                 block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
                 w_idx += block.num_conv
 
-        x = img = None
-        for res, cur_ws in zip(self.block_resolutions, block_ws):
+        if 'noise_mode' not in block_kwargs:
+            block_kwargs['noise_mode'] = self.default_noise_mode
+
+        # x and img are now supplied by the caller (both default to None) instead of being reset here.
+        for res, cur_ws in zip(block_res, block_ws):
             block = getattr(self, f'b{res}')
             x, img = block(x, img, cur_ws, **block_kwargs)
         return img
@@ -523,6 +556,7 @@
 
 #----------------------------------------------------------------------------
 
+# CHANGED: add support for truncated generator (upsampler)
 @persistence.persistent_class
 class Generator(torch.nn.Module):
     def __init__(self,
@@ -531,6 +565,7 @@
         w_dim,                      # Intermediate latent (W) dimensionality.
         img_resolution,             # Output resolution.
         img_channels,               # Number of output color channels.
+        input_resolution    = 4,    # Input resolution for truncated generator (upsampler)
         mapping_kwargs      = {},   # Arguments for MappingNetwork.
         **synthesis_kwargs,         # Arguments for SynthesisNetwork.
     ):
@@ -540,13 +575,14 @@
         self.w_dim = w_dim
         self.img_resolution = img_resolution
         self.img_channels = img_channels
-        self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
+        self.input_resolution = input_resolution
+        self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, input_resolution=input_resolution, **synthesis_kwargs)
         self.num_ws = self.synthesis.num_ws
         self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
 
-    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
+    def forward(self, z, c, x=None, img=None, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
         ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
-        img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+        img = self.synthesis(ws, x=x, img=img, update_emas=update_emas, **synthesis_kwargs)
         return img
 
 #----------------------------------------------------------------------------
@@ -730,6 +766,7 @@
 
 #----------------------------------------------------------------------------
 
+# CHANGED: add reconstruction decoder to discriminator
 @persistence.persistent_class
 class Discriminator(torch.nn.Module):
     def __init__(self,
@@ -745,6 +782,7 @@
         block_kwargs        = {},       # Arguments for DiscriminatorBlock.
         mapping_kwargs      = {},       # Arguments for MappingNetwork.
         epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
+        recon = True,                   # Attach a ConvDecoder; forward() then returns (x, recon) instead of x.
     ):
         super().__init__()
         self.c_dim = c_dim
@@ -774,6 +812,11 @@
         if c_dim > 0:
             self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
         self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+        self.recon = recon
+        if self.recon:
+            self.decoder = ConvDecoder(in_channel=channels_dict[4],
+                                       out_channel=img_channels, in_res=4,
+                                       out_res=img_resolution)
 
     def forward(self, img, c, update_emas=False, **block_kwargs):
         _ = update_emas # unused
@@ -785,8 +828,13 @@
         cmap = None
         if self.c_dim > 0:
             cmap = self.mapping(None, c)
+        if self.recon:
+            recon = self.decoder(x.float())
         x = self.b4(x, img, cmap)
-        return x
+        if self.recon:
+            return x, recon
+        else:
+            return x
 
     def extra_repr(self):
         return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
